{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rcParams\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "rcParams['figure.figsize'] = (12,6)\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torchvision import models\n",
    "\n",
    "from nupic.torch.modules import KWinners2d\n",
    "from ray import tune\n",
    "from torchsummary import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 135,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment settings, grouped by concern. `model` is a Ray Tune grid-search\n",
    "# axis, so one trial is produced per listed model name.\n",
    "exp_config = {\n",
    "    # ---- model related ----\n",
    "    'device': 'cuda',\n",
    "    # alternative: 'network': tune.grid_search(['vgg19_bn_kw', 'vgg19_bn'])\n",
    "    'network': 'VGG19',\n",
    "    'num_classes': 10,\n",
    "    'model': tune.grid_search(['BaseModel', 'SparseModel', 'SET', 'DSNN']),\n",
    "    # alternative: 'model': tune.grid_search(['SET', 'DSNN', 'DSNN_Flip', 'DSNN_Correct'])\n",
    "    'dataset_name': 'CIFAR10',\n",
    "    'augment_images': True,\n",
    "    'stats_mean': (0.4914, 0.4822, 0.4465),\n",
    "    'stats_std': (0.2023, 0.1994, 0.2010),\n",
    "    'data_dir': '~/nta/datasets',\n",
    "    # ---- optimizer related ----\n",
    "    'optim_alg': 'SGD',\n",
    "    'momentum': 0.9,\n",
    "    'learning_rate': 0.01,\n",
    "    'lr_scheduler': 'MultiStepLR',\n",
    "    'lr_milestones': [250, 290],\n",
    "    'lr_gamma': 0.10,\n",
    "    'debug_weights': True,\n",
    "    # ---- sparse related ----\n",
    "    'epsilon': 100,\n",
    "    'start_sparse': 1,\n",
    "    'end_sparse': None,\n",
    "    'debug_sparse': True,\n",
    "    'flip': False,\n",
    "    'weight_prune_perc': 0.3,\n",
    "    'grad_prune_perc': 0,\n",
    "    'percent_on': 0.3,\n",
    "    'boost_strength': 1.4,\n",
    "    'boost_strength_factor': 0.7,\n",
    "    'test_noise': True,\n",
    "    'noise_level': 0.1,\n",
    "    # k-winners is now selected through this parameter\n",
    "    'kwinners': True,\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<bound method VGG19._kwinners of VGG19()>\n",
      "<class 'nupic.torch.modules.k_winners.KWinners2d'>\n",
      "KWinners2d(channels=10, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n"
     ]
    }
   ],
   "source": [
    "# Build the VGG19 class from the `networks` module using the experiment config above.\n",
    "from networks import VGG19\n",
    "model = VGG19(exp_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### How to replace ReLUs with another non-linear activation function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh torchvision VGG-19 (batch-norm variant); no arguments are passed, so library defaults apply.\n",
    "model = models.vgg19_bn()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (2): ReLU(inplace)\n",
       "  (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (5): ReLU(inplace)\n",
       "  (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (9): ReLU(inplace)\n",
       "  (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (12): ReLU(inplace)\n",
       "  (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (16): ReLU(inplace)\n",
       "  (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (19): ReLU(inplace)\n",
       "  (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (22): ReLU(inplace)\n",
       "  (23): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (24): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (25): ReLU(inplace)\n",
       "  (26): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (27): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (29): ReLU(inplace)\n",
       "  (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (32): ReLU(inplace)\n",
       "  (33): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (34): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (35): ReLU(inplace)\n",
       "  (36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (38): ReLU(inplace)\n",
       "  (39): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (42): ReLU(inplace)\n",
       "  (43): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (44): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (45): ReLU(inplace)\n",
       "  (46): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (47): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (48): ReLU(inplace)\n",
       "  (49): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (50): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (51): ReLU(inplace)\n",
       "  (52): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the stock feature stack (Conv2d / BatchNorm2d / ReLU / MaxPool2d) before modifying it.\n",
    "model.features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the VGG-19 (batch-norm) feature extractor with k-winners activations:\n",
    "#   * every MaxPool2d is dropped,\n",
    "#   * every ReLU becomes a KWinners2d sized to the preceding Conv2d,\n",
    "#   * all other layers (Conv2d, BatchNorm2d) are kept as-is.\n",
    "model = models.vgg19_bn()\n",
    "config = {\n",
    "    'percent_on': 0.3,\n",
    "    'boost_strength': 1.4,\n",
    "    'boost_strength_factor': 0.7\n",
    "}\n",
    "new_features = []\n",
    "# Output channels of the most recent Conv2d; None until the first conv is seen.\n",
    "last_conv_out_channels = None\n",
    "for layer in model.features:\n",
    "    # remove max pooling\n",
    "    if isinstance(layer, nn.MaxPool2d):\n",
    "        continue\n",
    "    if isinstance(layer, nn.Conv2d):\n",
    "        last_conv_out_channels = layer.out_channels\n",
    "    # switch ReLU to KWinners2d\n",
    "    elif isinstance(layer, nn.ReLU):\n",
    "        # KWinners2d's `channels` must match its input feature maps, so fail\n",
    "        # loudly rather than hit a NameError if no conv has been seen yet.\n",
    "        if last_conv_out_channels is None:\n",
    "            raise ValueError('Found a ReLU before any Conv2d; cannot size KWinners2d')\n",
    "        layer = KWinners2d(\n",
    "            channels=last_conv_out_channels,\n",
    "            percent_on=config['percent_on'],\n",
    "            boost_strength=config['boost_strength'],\n",
    "            boost_strength_factor=config['boost_strength_factor'])\n",
    "    # the layer (original or replacement) is added in every remaining case\n",
    "    new_features.append(layer)\n",
    "\n",
    "model.features = nn.Sequential(*new_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "VGG(\n",
       "  (features): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): KWinners2d(channels=64, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (5): KWinners2d(channels=64, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (8): KWinners2d(channels=128, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (9): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (11): KWinners2d(channels=128, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (14): KWinners2d(channels=256, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (15): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (16): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (17): KWinners2d(channels=256, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (18): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (19): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (20): KWinners2d(channels=256, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (21): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (22): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (23): KWinners2d(channels=256, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (26): KWinners2d(channels=512, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (29): KWinners2d(channels=512, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (32): KWinners2d(channels=512, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (33): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (34): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (35): KWinners2d(channels=512, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (38): KWinners2d(channels=512, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (39): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (40): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (41): KWinners2d(channels=512, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (42): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (43): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (44): KWinners2d(channels=512, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "    (45): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (46): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (47): KWinners2d(channels=512, n=0, percent_on=0.3, boost_strength=1.4, duty_cycle_period=1000)\n",
       "  )\n",
       "  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
       "  (classifier): Sequential(\n",
       "    (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
       "    (1): ReLU(inplace)\n",
       "    (2): Dropout(p=0.5)\n",
       "    (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
       "    (4): ReLU(inplace)\n",
       "    (5): Dropout(p=0.5)\n",
       "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the full module tree to check the rebuilt `features` stack.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1           [-1, 64, 32, 32]           1,792\n",
      "       BatchNorm2d-2           [-1, 64, 32, 32]             128\n",
      "        KWinners2d-3           [-1, 64, 32, 32]               0\n",
      "            Conv2d-4           [-1, 64, 32, 32]          36,928\n",
      "       BatchNorm2d-5           [-1, 64, 32, 32]             128\n",
      "        KWinners2d-6           [-1, 64, 32, 32]               0\n",
      "            Conv2d-7          [-1, 128, 32, 32]          73,856\n",
      "       BatchNorm2d-8          [-1, 128, 32, 32]             256\n",
      "        KWinners2d-9          [-1, 128, 32, 32]               0\n",
      "           Conv2d-10          [-1, 128, 32, 32]         147,584\n",
      "      BatchNorm2d-11          [-1, 128, 32, 32]             256\n",
      "       KWinners2d-12          [-1, 128, 32, 32]               0\n",
      "           Conv2d-13          [-1, 256, 32, 32]         295,168\n",
      "      BatchNorm2d-14          [-1, 256, 32, 32]             512\n",
      "       KWinners2d-15          [-1, 256, 32, 32]               0\n",
      "           Conv2d-16          [-1, 256, 32, 32]         590,080\n",
      "      BatchNorm2d-17          [-1, 256, 32, 32]             512\n",
      "       KWinners2d-18          [-1, 256, 32, 32]               0\n",
      "           Conv2d-19          [-1, 256, 32, 32]         590,080\n",
      "      BatchNorm2d-20          [-1, 256, 32, 32]             512\n",
      "       KWinners2d-21          [-1, 256, 32, 32]               0\n",
      "           Conv2d-22          [-1, 256, 32, 32]         590,080\n",
      "      BatchNorm2d-23          [-1, 256, 32, 32]             512\n",
      "       KWinners2d-24          [-1, 256, 32, 32]               0\n",
      "           Conv2d-25          [-1, 512, 32, 32]       1,180,160\n",
      "      BatchNorm2d-26          [-1, 512, 32, 32]           1,024\n",
      "       KWinners2d-27          [-1, 512, 32, 32]               0\n",
      "           Conv2d-28          [-1, 512, 32, 32]       2,359,808\n",
      "      BatchNorm2d-29          [-1, 512, 32, 32]           1,024\n",
      "       KWinners2d-30          [-1, 512, 32, 32]               0\n",
      "           Conv2d-31          [-1, 512, 32, 32]       2,359,808\n",
      "      BatchNorm2d-32          [-1, 512, 32, 32]           1,024\n",
      "       KWinners2d-33          [-1, 512, 32, 32]               0\n",
      "           Conv2d-34          [-1, 512, 32, 32]       2,359,808\n",
      "      BatchNorm2d-35          [-1, 512, 32, 32]           1,024\n",
      "       KWinners2d-36          [-1, 512, 32, 32]               0\n",
      "           Conv2d-37          [-1, 512, 32, 32]       2,359,808\n",
      "      BatchNorm2d-38          [-1, 512, 32, 32]           1,024\n",
      "       KWinners2d-39          [-1, 512, 32, 32]               0\n",
      "           Conv2d-40          [-1, 512, 32, 32]       2,359,808\n",
      "      BatchNorm2d-41          [-1, 512, 32, 32]           1,024\n",
      "       KWinners2d-42          [-1, 512, 32, 32]               0\n",
      "           Conv2d-43          [-1, 512, 32, 32]       2,359,808\n",
      "      BatchNorm2d-44          [-1, 512, 32, 32]           1,024\n",
      "       KWinners2d-45          [-1, 512, 32, 32]               0\n",
      "           Conv2d-46          [-1, 512, 32, 32]       2,359,808\n",
      "      BatchNorm2d-47          [-1, 512, 32, 32]           1,024\n",
      "       KWinners2d-48          [-1, 512, 32, 32]               0\n",
      "AdaptiveAvgPool2d-49            [-1, 512, 7, 7]               0\n",
      "           Linear-50                 [-1, 4096]     102,764,544\n",
      "             ReLU-51                 [-1, 4096]               0\n",
      "          Dropout-52                 [-1, 4096]               0\n",
      "           Linear-53                 [-1, 4096]      16,781,312\n",
      "             ReLU-54                 [-1, 4096]               0\n",
      "          Dropout-55                 [-1, 4096]               0\n",
      "           Linear-56                 [-1, 1000]       4,097,000\n",
      "================================================================\n",
      "Total params: 143,678,248\n",
      "Trainable params: 143,678,248\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.01\n",
      "Forward/backward pass size (MB): 129.39\n",
      "Params size (MB): 548.09\n",
      "Estimated Total Size (MB): 677.49\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Per-layer output shapes and parameter counts for a (3, 32, 32) input.\n",
    "summary(model, input_size=(3,32,32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nupic.torch.modules import KWinners2d\n",
    "# Copy the feature stack, dropping MaxPool2d and swapping each ReLU for a\n",
    "# KWinners2d. The KWinners2d channel count is taken from the preceding Conv2d;\n",
    "# building it with channels=0 makes the forward pass fail with a size mismatch.\n",
    "new_features = []\n",
    "last_conv_out_channels = None\n",
    "for layer in model.features:\n",
    "    if isinstance(layer, nn.MaxPool2d):\n",
    "        continue\n",
    "    if isinstance(layer, nn.Conv2d):\n",
    "        last_conv_out_channels = layer.out_channels\n",
    "    elif isinstance(layer, nn.ReLU):\n",
    "        if last_conv_out_channels is None:\n",
    "            raise ValueError('Found a ReLU before any Conv2d; cannot size KWinners2d')\n",
    "        layer = KWinners2d(last_conv_out_channels, percent_on=0.3)\n",
    "    new_features.append(layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "The size of tensor a (64) must match the size of tensor b (0) at non-singleton dimension 1",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-46-c069c7653900>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSequential\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mnew_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0msummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m32\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/miniconda3/envs/numenta/lib/python3.7/site-packages/torchsummary/torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[0;34m(model, input_size, batch_size, device)\u001b[0m\n\u001b[1;32m     70\u001b[0m     \u001b[0;31m# make a forward pass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     71\u001b[0m     \u001b[0;31m# print(x.shape)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 72\u001b[0;31m     \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     74\u001b[0m     \u001b[0;31m# remove these hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/numenta/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    491\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    492\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 493\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    494\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    495\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/numenta/lib/python3.7/site-packages/torchvision/models/vgg.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     40\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     41\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 42\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     43\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mavgpool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     44\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/numenta/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    491\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    492\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 493\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    494\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    495\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/numenta/lib/python3.7/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m     90\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     91\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_modules\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 92\u001b[0;31m             \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     93\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/miniconda3/envs/numenta/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    491\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    492\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 493\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    494\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    495\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/nta/nupic.torch/nupic/torch/modules/k_winners.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    267\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    268\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 269\u001b[0;31m             \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mKWinners2d\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mduty_cycle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mboost_strength\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    270\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_duty_cycle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    271\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/nta/nupic.torch/nupic/torch/functions/k_winners.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(ctx, x, duty_cycles, k, boost_strength)\u001b[0m\n\u001b[1;32m    165\u001b[0m             \u001b[0mtarget_density\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    166\u001b[0m             \u001b[0mboost_factors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget_density\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mduty_cycles\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mboost_strength\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m             \u001b[0mboosted\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mboost_factors\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    168\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    169\u001b[0m             \u001b[0mboosted\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (64) must match the size of tensor b (0) at non-singleton dimension 1"
     ]
    }
   ],
   "source": [
    "# Wrap the layers collected in new_features in an nn.Sequential, install it as\n",
    "# model.features, then call torchsummary's summary() for a (3, 32, 32) input.\n",
    "# NOTE(review): the output saved with this cell is a RuntimeError raised inside\n",
    "# KWinners2d (tensor size 64 vs 0 at dimension 1). The KWinners2d layers appear to\n",
    "# have been created with channels=0 -- confirm where new_features is assembled.\n",
    "model.features = nn.Sequential(*new_features)\n",
    "summary(model, input_size=(3,32,32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "VGG(\n",
       "  (features): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (5): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (8): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (9): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (10): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (11): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (14): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (15): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (16): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (17): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (18): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (19): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (20): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (21): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (22): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (23): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (26): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (29): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (32): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (33): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (34): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (35): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (38): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (39): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (40): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (41): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (42): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (43): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (44): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (45): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (46): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (47): KWinners2d(channels=0, n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "  )\n",
       "  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
       "  (classifier): Sequential(\n",
       "    (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
       "    (1): ReLU(inplace)\n",
       "    (2): Dropout(p=0.5)\n",
       "    (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
       "    (4): ReLU(inplace)\n",
       "    (5): Dropout(p=0.5)\n",
       "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the model repr to inspect the layer stack after features was replaced.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nupic.torch.modules import KWinners\n",
    "\n",
    "# Replace every ReLU in model.features with a KWinners activation, in place.\n",
    "# Assigning to an existing index keeps the number of layers unchanged, so the\n",
    "# enumeration is not disturbed by the replacement.\n",
    "# NOTE(review): the first argument (n) is 0 here. The traceback saved elsewhere in\n",
    "# this notebook shows a k-winners layer failing with a size-0 tensor, so n\n",
    "# presumably has to be the real unit count of the preceding layer -- confirm.\n",
    "for idx, layer in enumerate(model.features):\n",
    "    if isinstance(layer, nn.ReLU):\n",
    "        model.features[idx] = KWinners(0, 0.3)  # positional args: n, percent_on"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "VGG(\n",
       "  (features): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (5): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (9): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (12): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (16): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (19): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (22): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (23): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (24): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (25): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (26): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (27): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (29): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (32): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (33): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (34): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (35): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (36): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (37): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (38): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (39): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (42): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (43): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (44): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (45): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (46): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (47): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (48): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (49): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "    (50): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (51): KWinners(n=0, percent_on=0.3, boost_strength=1.0, duty_cycle_period=1000)\n",
       "    (52): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
       "  (classifier): Sequential(\n",
       "    (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
       "    (1): ReLU(inplace)\n",
       "    (2): Dropout(p=0.5)\n",
       "    (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
       "    (4): ReLU(inplace)\n",
       "    (5): Dropout(p=0.5)\n",
       "    (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the model repr to check which activation layers model.features contains.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__call__',\n",
       " '__class__',\n",
       " '__constants__',\n",
       " '__delattr__',\n",
       " '__dict__',\n",
       " '__dir__',\n",
       " '__doc__',\n",
       " '__eq__',\n",
       " '__format__',\n",
       " '__ge__',\n",
       " '__getattr__',\n",
       " '__getattribute__',\n",
       " '__gt__',\n",
       " '__hash__',\n",
       " '__init__',\n",
       " '__init_subclass__',\n",
       " '__le__',\n",
       " '__lt__',\n",
       " '__module__',\n",
       " '__ne__',\n",
       " '__new__',\n",
       " '__reduce__',\n",
       " '__reduce_ex__',\n",
       " '__repr__',\n",
       " '__setattr__',\n",
       " '__setstate__',\n",
       " '__sizeof__',\n",
       " '__str__',\n",
       " '__subclasshook__',\n",
       " '__weakref__',\n",
       " '_apply',\n",
       " '_backend',\n",
       " '_backward_hooks',\n",
       " '_buffers',\n",
       " '_forward_hooks',\n",
       " '_forward_pre_hooks',\n",
       " '_get_name',\n",
       " '_load_from_state_dict',\n",
       " '_load_state_dict_pre_hooks',\n",
       " '_modules',\n",
       " '_named_members',\n",
       " '_parameters',\n",
       " '_register_load_state_dict_pre_hook',\n",
       " '_register_state_dict_hook',\n",
       " '_slow_forward',\n",
       " '_state_dict_hooks',\n",
       " '_tracing_name',\n",
       " '_version',\n",
       " 'add_module',\n",
       " 'apply',\n",
       " 'bias',\n",
       " 'buffers',\n",
       " 'children',\n",
       " 'cpu',\n",
       " 'cuda',\n",
       " 'double',\n",
       " 'dump_patches',\n",
       " 'eval',\n",
       " 'extra_repr',\n",
       " 'float',\n",
       " 'forward',\n",
       " 'half',\n",
       " 'in_features',\n",
       " 'load_state_dict',\n",
       " 'modules',\n",
       " 'named_buffers',\n",
       " 'named_children',\n",
       " 'named_modules',\n",
       " 'named_parameters',\n",
       " 'out_features',\n",
       " 'parameters',\n",
       " 'register_backward_hook',\n",
       " 'register_buffer',\n",
       " 'register_forward_hook',\n",
       " 'register_forward_pre_hook',\n",
       " 'register_parameter',\n",
       " 'reset_parameters',\n",
       " 'share_memory',\n",
       " 'state_dict',\n",
       " 'to',\n",
       " 'train',\n",
       " 'training',\n",
       " 'type',\n",
       " 'weight',\n",
       " 'zero_grad']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List the attribute and method names available on m.\n",
    "# NOTE(review): m is not defined in this cell; it must already exist in the kernel.\n",
    "dir(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class cauchy_activation(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(cauchy_activation, self).__init__()\n",
    "        \n",
    "    def activation(self, inp):\n",
    "        return pdf_cauchy_distribution(self.inp)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
